package com.congee02.bt.dfs;

import com.congee02.bt.TreeNode;

public class PruneTree {

    /**
     * Removes, in place, every subtree that contains no node whose value is {@code 1}.
     *
     * <p>The tree is walked bottom-up. A child link is cleared when the subtree hanging
     * off it has no node with {@code val == 1}.
     *
     * @param root root of the tree to prune; may be {@code null}
     * @return {@code root} itself if its subtree still contains a node with value
     *     {@code 1} after pruning, otherwise {@code null}
     */
    public TreeNode pruneTree(TreeNode root) {
        return pruneSubtree(root);
    }

    /**
     * Prunes the subtree rooted at {@code node} and reports what survives.
     *
     * <p>Recursion depth grows with the height of the tree.
     *
     * @param node subtree root; may be {@code null}
     * @return {@code node} if it or any descendant has value {@code 1}, else {@code null}
     */
    private TreeNode pruneSubtree(TreeNode node) {
        if (node == null) {
            return null;
        }

        // Post-order: prune both children first so their links reflect what survived.
        node.left = pruneSubtree(node.left);
        node.right = pruneSubtree(node.right);

        // A surviving child link means that side still contains a 1.
        boolean keep = node.val == 1 || node.left != null || node.right != null;
        return keep ? node : null;
    }

}
